import pandas as pd
from fbprophet import Prophet
from fbprophet.plot import add_changepoints_to_plot
import numpy as np
import matplotlib.pyplot as plt

# Load the training data set and echo its first rows as a sanity check.
# NOTE(review): Prophet.fit() below presumably needs a 'ds' (date) column and
# a 'y' (target) column in this CSV — confirm the file uses those names.
TRAIN_CSV_PATH = 'data/train.csv'
live_df = pd.read_csv(TRAIN_CSV_PATH)

first_rows = live_df.head()
print(first_rows)

# Holiday calendar passed to Prophet. Every row is one date; lower_window=0
# and upper_window=1 make each date's effect also cover the following day.
# NOTE(review): only 2023 dates are listed. Prophet applies a holiday effect
# only on dates present in this table, so forecast dates outside 2023 get no
# holiday effect — confirm against the data/forecast range or add more years.

# Dates tagged 'Lunar_festivals' — presumably the 2023 Spring Festival,
# Qingming, Dragon Boat and Mid-Autumn breaks.
LUNAR_FESTIVAL_DATES = [
    '2023-01-21', '2023-01-22', '2023-01-23', '2023-01-24',
    '2023-01-25', '2023-01-26', '2023-01-27',
    '2023-04-05',
    '2023-06-22', '2023-06-23', '2023-06-24',
    '2023-09-29', '2023-09-30',
]

# Dates tagged 'china' — presumably the 2023 New Year's Day, Labour Day and
# National Day breaks.
STATUTORY_HOLIDAY_DATES = [
    '2023-01-01', '2023-01-02',
    '2023-05-01', '2023-05-02', '2023-05-03',
    '2023-10-01', '2023-10-02', '2023-10-03',
    '2023-10-04', '2023-10-05', '2023-10-06',
]


def _make_holiday_frame(label, dates):
    """Return a Prophet holiday table that tags every date in *dates* with *label*."""
    return pd.DataFrame({
        'holiday': label,
        'ds': pd.to_datetime(dates),
        'lower_window': 0,
        'upper_window': 1,
    })


chinese_holiday = _make_holiday_frame('Lunar_festivals', LUNAR_FESTIVAL_DATES)
china_holiday = _make_holiday_frame('china', STATUTORY_HOLIDAY_DATES)
holidays = pd.concat((chinese_holiday, china_holiday))

# Fit Prophet on the training data, with holiday effects taken from the
# `holidays` table, then forecast one year ahead and plot the result.
model = Prophet(holidays=holidays)
model.fit(live_df)

# Build 365 future dates, one per calendar day, after the end of the training
# data. Use 'D' for daily steps. A string such as '1440m' does NOT mean 1440
# minutes: pandas upper-cases a lowercase 'm' to 'M' (month-end), which would
# request points 1440 months apart.
future = model.make_future_dataframe(periods=365, freq='D', include_history=False)
# Bare expressions are only echoed in a notebook/REPL; print so the preview is
# visible when this file runs as a script.
print(future.tail())

forecast = model.predict(future)
print(forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(365))

# Forecast plot with uncertainty interval.
fig1 = model.plot(forecast)

# Per-component breakdown of the forecast (e.g. trend, holidays, seasonalities).
fig = model.plot_components(forecast)

plt.show()
